import os
import re
import argparse
import pickle
import numpy as np
import wandb
import torch
import torch.optim as optim

from extract_esm2_representation import prepare_seq, extract_esm2_representation, load_pretrained_esm
from esmfold_modified_v1 import ContinousLinkerModel

from utils.data_util import get_seqs, get_distogram
from utils.crop import crop
from utils.misc import seed_everything, random_sampling, add_linker_to_data, load_esm2_seq_representation
from utils.evaluate_util import evaulate_contact_predicsion_from_distogram, compute_inter_chain_contact_precision_from_distogram
from utils.logger import Logger
# Module-level logger shared by every function and method in this script;
# it is the `logger` attribute of the project's Logger class.
logger = Logger.logger
import warnings
# Suppress all Python warnings process-wide, including those emitted by
# imported libraries.
warnings.filterwarnings('ignore')


def get_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behaviour.
            Passing an explicit list lets the parser be driven from notebooks
            or tests.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_mode', type=str, default='test', choices=['train', 'test'])
    parser.add_argument('--backbone', type=str, default='esmfold_v1', choices=['esmfold_v1'])
    parser.add_argument('--data_type', type=str, default='heterodimer', choices=['heterodimer'])
    parser.add_argument('--data_dir', type=str, default='data/')
    parser.add_argument('--backbone_dir', type=str, default='checkpoint/esmfold', help='directory of backbone model')
    parser.add_argument('--linker_len', type=int, default=25)
    parser.add_argument('--inter_weight', type=float, default=4, help='weight for the inter-chain ditogram loss')
    parser.add_argument('--chain_linker', type=str, default=None, help='the nearset aa linker of learnt linker')
    parser.add_argument('--num_recycles', type=int, default=1)
    parser.add_argument('--residue_gap', type=int, default=0)
    parser.add_argument('--precompute_esm2', default=False, action='store_true',help='whether to use the precomputed esm2 seq representation')

    # training
    parser.add_argument('--dynamic', default=False, action='store_true', help="whether to dynamically crop chains")
    parser.add_argument('--crop', default=False, action='store_true', help="whether to crop chains")
    parser.add_argument('--no_test', default=False, action='store_true', help="if true, don't test on test set during training")
    parser.add_argument('--crop_size', type=int, default=200, help='multi-chain cropping, the train seq len')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--clipnorm', type=float, default=0.1, help='clip the gradient norm')
    parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
    parser.add_argument('--weight_decay', type=float, default=0.0, help='weight_decay in optimizer')
    parser.add_argument('--temp', type=float, default=0.07, help='temperature in loss')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument("--print_every", type=int, default=200, help='log_interval')
    parser.add_argument("--patience", type=int, default=10, help='patience for early stopping')
    parser.add_argument('--pt_saving_dir', type=str, default='checkpoint/linker/')
    parser.add_argument('--cp_every', type=int, default=1, help="save checkpont every k epoch")
    parser.add_argument('--cp_file', type=str, default=None, help="checkpoint file")
    parser.add_argument('--few_shot', default=False, action='store_true', help="whether to train in a few shot samples")
    parser.add_argument('--num_samples', type=int, default=128, help="number of samples for few shot learning")

    # test only
    parser.add_argument('--pdb_case', default=None, type=str, help="the case we want to study")
    parser.add_argument('--output_pdb', default=False, action='store_true', help="whether to output pdb file")
    parser.add_argument('--keep_linker_in_output', default=False, action='store_true', help="if true, return output with linker")
    parser.add_argument('--save_logits', default=False, action='store_true', help="whether to save distogram logits")
    parser.add_argument('--logit_path', type=str, default=None)
    parser.add_argument('--test_name', type=str, default='test', help='valid, test, test2, vhvl68, vhvl171')
    parser.add_argument('--model_name', type=str, default=None, help='name of the trained model for prediction')
    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)


class LinkerTuning:
    """Train and evaluate the linker embeddings of a ``ContinousLinkerModel``.

    The instance owns the model, the distogram binning constants and
    ``self.data``, a dict keyed by split name ('train', 'valid', 'test', ...)
    that is filled by ``prepare_data``. Only the linker embedding weights are
    checkpointed (see ``save_model`` / ``load_model``).

    All configuration is read from ``self.args`` so the class does not depend
    on a module-level ``args`` global.
    """

    def __init__(self, args):
        """Build the model and, in train mode, load the train/valid/test splits.

        Args:
            args: parsed command-line namespace (see ``get_args``).
        """
        self.args = args
        if args.backbone == 'esmfold_v1':
            self.model = ContinousLinkerModel(args)
            # Binning handed to get_distogram() when distance maps are loaded.
            self.min_bin, self.max_bin, self.num_bins = 2.3125, 21.6875, 64

        self.data = {}
        if self.args.model_mode == 'train':
            for data_mode in ['train', 'valid', 'test']:
                use_crop = data_mode == 'train' and args.crop
                seq_path, dist_path = self._data_paths(data_mode, cropped=use_crop)
                self.prepare_data(seq_path, dist_path, data_mode, self.args.precompute_esm2)
            logger.info('train_size={}, valid_size={}, test_size={}'.format(
                        len(self.data['train']['distos'].keys()),
                        len(self.data['valid']['distos'].keys()),
                        len(self.data['test']['distos'].keys())))

    def _data_paths(self, data_mode, cropped=False):
        """Return ``(sequence_path, distance_map_path)`` inside ``args.data_dir``.

        With ``cropped=True`` the paths of the cropped *training* files are
        returned (the crop files only exist for the training split).
        """
        data_dir = self.args.data_dir
        if cropped:
            stem = 'train_crop' + str(self.args.crop_size)
            return (os.path.join(data_dir, stem + '_profiling.pickle'),
                    os.path.join(data_dir, stem + '_distance_map.pickle'))
        return (os.path.join(data_dir, data_mode + '_profiling.csv'),
                os.path.join(data_dir, data_mode + '_distance_map.pickle'))

    def _recrop_train_data(self, esm2, ep):
        """Re-crop the training chains, recompute their ESM2 reps and reload them.

        Args:
            esm2: the model returned by ``load_pretrained_esm``.
            ep: current epoch; added to ``args.seed`` so each epoch crops differently.

        Returns:
            ``esm2`` after it has been moved back to the CPU.
        """
        logger.info('Crop and precompute esm2 seq rep')
        seq_path, dist_path = self._data_paths('train')
        crop(seq_path,
             dist_path,
             crop_size=self.args.crop_size,
             spatial_crop_prob=0.5,
             cb_cb_threshold=10,
             seed=self.args.seed+ep)

        esm2 = esm2.eval().cuda()
        crop_seq_path, crop_dist_path = self._data_paths('train', cropped=True)
        crop_seqs = prepare_seq(crop_seq_path, max_len=None)
        extract_esm2_representation(esm2, crop_seqs,
                                    data_mode='train',
                                    data_dir=self.args.data_dir,
                                    linker_len=self.args.linker_len,
                                    crop=self.args.crop,
                                    crop_size=self.args.crop_size,
                                    renew=True)

        # load cropped train data
        self.prepare_data(crop_seq_path, crop_dist_path, data_mode='train',
                          precompute=self.args.precompute_esm2)
        # Move ESM2 off the GPU while the linker model trains.
        return esm2.cpu()

    def train(self):
        """Optimise the model with early stopping on the validation loss.

        Resumes from ``args.cp_file`` when given, otherwise initialises the
        linker embedding. Every ``args.cp_every`` epochs a checkpoint is
        written, the validation split is evaluated and, unless
        ``args.no_test`` is set, the test split as well. Metrics are logged to
        the logger and to wandb.

        Raises:
            ValueError: if ``args.cp_file`` contains no number from which the
                checkpoint epoch can be recovered.
        """
        cp_ep = 0
        if self.args.cp_file is not None:
            # continue training from checkpoint
            model_path = os.path.join(self.args.pt_saving_dir, self.args.cp_file)
            self.load_model(model_path)
            epoch_ids = re.findall(r'\d+', self.args.cp_file)
            if not epoch_ids:
                raise ValueError(
                    'cannot infer the epoch index from checkpoint file name {!r}'.format(self.args.cp_file))
            # The last number in the file name is the epoch the checkpoint was saved at.
            cp_ep = int(epoch_ids[-1])+1
            logger.info('Continue to train from checkpoint {}'.format(cp_ep-1))
        else:
            self.model.esmfold.initialize_linker_embed()

        self.model.convert_linker_to_aa_seq()
        self.model.freeze_esmfold()
        self.check_trainable_parameters()
        seed_everything(self.args)

        project = 'esmfold-pet-' + self.args.data_type
        if self.args.few_shot:
            project = project + '-fewshot'
        # Hand the hyper-parameters to init(); assigning to ``wandb.config``
        # merely rebinds the module attribute instead of updating the run config.
        wandb.init(project=project, config=vars(self.args))

        # optimizer
        optimizer = optim.Adam(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay)
        total_steps = int(len(self.data['train']['distos'].keys())/self.args.batch_size) * self.args.num_epochs
        # NOTE(review): the scheduler is constructed but never stepped, so the
        # learning rate stays at args.lr. Left untouched to keep the training
        # schedule unchanged -- confirm whether scheduler.step() was intended.
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=0.1*self.args.lr)

        dynamic_crop = self.args.crop and self.args.dynamic
        if dynamic_crop:
            esm2 = load_pretrained_esm(model_name='esmfold_3B_v1.pt', model_dir=self.args.backbone_dir)

        # for early stopping
        counter, min_valid_loss, best_checkpoint = 0, 10000000, 0

        for ep in range(cp_ep, cp_ep+self.args.num_epochs):
            logger.info('Train: epoch {} '.format(ep))
            self.model.train()

            if ep > 0 and dynamic_crop:
                esm2 = self._recrop_train_data(esm2, ep)

            # random shuffle data
            train_names = self.shuffle_data(ep)

            seq_rep = None
            epoch_loss = 0
            steps_loss = 0
            for i, name in enumerate(train_names):
                seq = self.data['train']['seqs'][name]
                if self.args.precompute_esm2:
                    seq_rep = self.data['train']['reps'][name]

                labels = torch.tensor(self.data['train']['distos'][name], dtype=torch.int64).unsqueeze(0).cuda()
                length = self.data['train']['lengths'][name]

                # Clear gradients before every step; otherwise each update would
                # also contain the gradients of all earlier samples of the epoch.
                optimizer.zero_grad()
                _, loss, _ = self.model(seq, seq_rep=seq_rep, labels=labels, lengths=length)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clipnorm)
                optimizer.step()

                loss_value = loss.item()
                epoch_loss += loss_value
                steps_loss += loss_value
                if (i+1) % self.args.print_every == 0:
                    logger.info("Train: Epoch {}, Step {}, loss={:.4f}".format(ep, i, steps_loss/self.args.print_every))
                    wandb.log({"train_loss (steps)": steps_loss/self.args.print_every})
                    steps_loss = 0

            # max(..., 1) keeps an empty training split from raising here.
            mean_epoch_loss = epoch_loss / max(len(train_names), 1)
            logger.info("Train: Epoch {} end, loss={:.4f}".format(ep, mean_epoch_loss))
            wandb.log({"train_loss (epoch)": mean_epoch_loss})

            # save model checkpoint
            if (ep+1) % self.args.cp_every == 0:
                dir_path = os.path.join(self.args.pt_saving_dir, self.args.data_type)
                os.makedirs(dir_path, exist_ok=True)
                save_path = "{}/checkpoint_{}.pkl".format(dir_path, ep)
                self.save_model(save_path)
                logger.info('Model save to {}'.format(save_path))

                self.model.convert_linker_to_aa_seq()

                valid_loss, valid_precision = self.evaluate(mode='valid', print_metrics=False)
                logger.info("Valid: Epoch {} end, distogram logloss={:.4f}, contact precision={:.4f}\n".format(
                    ep, valid_loss, valid_precision))
                wandb.log({"valid_distogram_loss": valid_loss, "valid_topL5_contact_precision": valid_precision})
                if valid_loss > min_valid_loss:
                    counter += 1
                    if counter >= self.args.patience:
                        logger.info('No improvent for {} epochs, the best checkpiont is {}'.format(counter, best_checkpoint))
                        break
                else:
                    # An equal loss also counts as "not worse" and resets the counter.
                    min_valid_loss = valid_loss
                    best_checkpoint = ep
                    counter = 0

                # evaluate on test set
                if not self.args.no_test:
                    test_loss, test_precision = self.evaluate(mode='test', print_metrics=False)
                    logger.info("Test: Epoch {} end, distogram logloss={:.4f}, contact precision={:.4f}\n".format(
                        ep, test_loss, test_precision))
                    wandb.log({"test_distogram_loss": test_loss, "test_topL5_contact_precision": test_precision})
        wandb.finish()

    @torch.no_grad()
    def predict(self, mode='valid'):
        """Run the model over every sequence of ``self.data[mode]``.

        Args:
            mode: key of the split in ``self.data``.

        Returns:
            ``(logit_dict, loss_dict, pdb_dict)``, each keyed by sample name.
            ``loss_dict`` only has entries for samples that have a distogram
            label and for which the model returned a loss.
        """
        self.model.eval()
        logit_dict, loss_dict, pdb_dict = {}, {}, {}
        seq_rep = None
        for name, seq in self.data[mode]['seqs'].items():
            if self.args.precompute_esm2:
                seq_rep = self.data[mode]['reps'][name]
            # Only the lookup is guarded: a missing label means "predict without loss".
            try:
                disto = self.data[mode]['distos'][name]
            except KeyError:
                labels = None
            else:
                labels = torch.tensor(disto, dtype=torch.int64).unsqueeze(0).cuda()
            length = self.data[mode]['lengths'][name]
            logits, loss, predicted_pdb = self.model(seq, seq_rep=seq_rep, labels=labels, lengths=length)
            logit_dict[name] = logits.cpu().squeeze()  # (seqlen, seqlen, 64)
            if labels is not None and loss is not None:
                loss_dict[name] = loss.cpu()
            pdb_dict[name] = predicted_pdb
        return logit_dict, loss_dict, pdb_dict

    def evaluate(self, mode='valid', print_metrics=True):
        """Predict on a split and report distogram loss / contact precision.

        The inter-chain contact precision is used for model selection.

        In test mode (``args.model_mode == 'test'``) the split is loaded from
        ``args.data_dir`` first. Logits and PDB files are saved when
        ``args.save_logits`` / ``args.output_pdb`` are set.

        Args:
            mode: split name, e.g. 'valid' or 'test'.
            print_metrics: if True the metrics are reported through
                ``evaulate_contact_predicsion_from_distogram`` and nothing is
                returned.

        Returns:
            ``(distogram_logloss, contact_precision)`` when ``print_metrics`` is
            False, the split has distograms and ``args.pdb_case`` is unset;
            otherwise ``None``.
        """
        # NOTE(review): presumably the last distogram bin still counted as a
        # contact for the given binning -- confirm against utils.evaluate_util.
        upper_bound = 19 if self.num_bins == 64 else 5

        if self.args.model_mode == 'test':
            seq_path, dist_path = self._data_paths(mode)
            if not os.path.exists(dist_path):
                # No ground truth available: predict only.
                dist_path = None
            self.prepare_data(seq_path, dist_path, mode, self.args.precompute_esm2)
            logger.info('test_size={}'.format(len(self.data[mode]['seqs'].keys())))

        disto_logits, disto_loss, pdb_dict = self.predict(mode)

        if self.args.save_logits:
            self.save_predicted_logits(disto_logits, mode)

        if self.args.output_pdb:
            self.save_predicted_pdbs(pdb_dict, mode)

        if 'distos' in self.data[mode] and self.args.pdb_case is None:
            disto_logloss = np.mean(list(disto_loss.values()))
            if print_metrics:
                evaulate_contact_predicsion_from_distogram(disto_logits, self.data[mode]['distos'], self.data[mode]['lengths'],
                                                        disto_loss, upper_bound)
            else:
                contact_precision = compute_inter_chain_contact_precision_from_distogram(disto_logits, self.data[mode]['distos'],
                                                        self.data[mode]['lengths'], upper_bound)
                return disto_logloss, contact_precision

    def prepare_data(self, seq_path, dist_path, data_mode, precompute=True):
        """Load one split into ``self.data[data_mode]``.

        Args:
            seq_path: file read by ``get_seqs``.
            dist_path: pickle of distance maps, or ``None`` when no ground
                truth is available.
            data_mode: split name used as the key in ``self.data``.
            precompute: if True, also attach precomputed ESM2 sequence
                representations under the 'reps' key.
        """
        sequences = get_seqs(seq_path)
        if dist_path is not None:
            # NOTE: pickle must only be used on trusted, locally produced files.
            with open(dist_path, mode='rb') as f:
                dist_maps = pickle.load(f)
            distograms = get_distogram(dist_maps, min_bin=self.min_bin, max_bin=self.max_bin, num_bins=self.num_bins)
        else:
            distograms = None

        if self.args.pdb_case is not None:
            pdb_case_list = self.args.pdb_case.split(',')
            sequences = {name: sequences[name] for name in pdb_case_list}
            # distograms is None when no distance-map file exists for this split.
            if distograms is not None:
                distograms = {name: distograms[name] for name in pdb_case_list}

        self.data[data_mode] = add_linker_to_data(sequences, distograms, linker_len=self.args.linker_len, folding=True, max_len=None)

        if self.args.few_shot and data_mode == 'train':
            self.data[data_mode] = random_sampling(self.data[data_mode], n=self.args.num_samples, seed=self.args.seed)

        if precompute:
            seq_reps = load_esm2_seq_representation(self.args.data_dir, data_mode, self.args.linker_len,
                                                    crop=self.args.crop, crop_size=self.args.crop_size)
            self.data[data_mode]['reps'] = {name: seq_reps[name] for name in self.data[data_mode]['seqs']}

    def shuffle_data(self, ep):
        """Return the training sample names in an epoch-dependent random order.

        The global NumPy RNG is reseeded with ``args.seed + ep``, so the order
        is reproducible per epoch.
        """
        names = list(self.data['train']['seqs'].keys())
        np.random.seed(self.args.seed+ep)
        np.random.shuffle(names)
        return names

    def save_model(self, save_path):
        """Pickle the linker embedding weights to ``save_path``."""
        linker_embeds = self.model.esmfold.linker_embedding.weight.data.cpu()
        weights = {"linker_embeds": linker_embeds}
        with open(save_path, mode='wb') as f:
            pickle.dump(weights, f)

    def load_model(self, model_path):
        """Load linker embedding weights written by ``save_model``.

        The weights are moved to the device of the model's first parameter.
        """
        device = next(self.model.parameters()).device
        # NOTE: pickle.load executes arbitrary code; only load trusted checkpoints.
        with open(model_path, mode='rb') as f:
            weights = pickle.load(f)
        self.model.esmfold.linker_embedding.weight.data = weights['linker_embeds'].to(device)

    def check_trainable_parameters(self):
        """Log the names of all parameters that still require gradients."""
        parameters = [name for name, param in self.model.named_parameters() if param.requires_grad]
        logger.info(parameters)

    def save_predicted_logits(self, disto_logits, mode):
        """Pickle the predicted distogram logits into ``args.logit_path``.

        Args:
            disto_logits: dict of logits keyed by sample name.
            mode: split name, used in the output file name.

        Raises:
            ValueError: if ``args.logit_path`` is not set.
        """
        if self.args.logit_path is None:
            raise ValueError('--logit_path must be set when --save_logits is used')
        os.makedirs(self.args.logit_path, exist_ok=True)

        file_stem = "{}_{}_{}_rec{}_gap{}".format(self.args.data_type, mode, self.args.model_name,
                                                  self.args.num_recycles, self.args.residue_gap)
        if self.args.pdb_case is not None:
            file_stem = "{}_{}".format(file_stem, self.args.pdb_case[:4])
        # os.path.join keeps the file inside logit_path whether or not the
        # directory was given with a trailing separator.
        save_path = os.path.join(self.args.logit_path, file_stem + ".pickle")

        with open(save_path, mode='wb') as f:
            pickle.dump(disto_logits, f)
        logger.info('disto logits save to {}'.format(save_path))

    def save_predicted_pdbs(self, pdb_dict, mode):
        """Write each predicted structure to ``<dir>/<name>.pdb``.

        The directory is derived from ``args.pt_saving_dir``, the data type,
        the model name, ``mode`` and the recycle / residue-gap settings.

        Args:
            pdb_dict: dict mapping sample name to the PDB file content (str).
            mode: split name, used in the output directory name.
        """
        dir_path = "{}/{}/{}/{}/{}_rec{}_gap{}".format(self.args.pt_saving_dir, 'pdbs', self.args.data_type, self.args.model_name,
                                                        mode, self.args.num_recycles, self.args.residue_gap)
        os.makedirs(dir_path, exist_ok=True)

        for name, predicted_pdb in pdb_dict.items():
            save_path = "{}/{}.pdb".format(dir_path, name)
            with open(save_path, "w") as f:
                f.write(predicted_pdb)



if __name__ == '__main__':
    # Script entry point: train the linker, or load a checkpoint and evaluate
    # it on every split listed in --test_name.

    args = get_args()
    logger.info(args)

    # Test mode needs a checkpoint. Fail early with a clear message instead of
    # a TypeError from os.path.join after the model has already been built.
    if args.model_mode != 'train' and args.cp_file is None:
        raise SystemExit('--cp_file is required when --model_mode is "test"')

    ptuning = LinkerTuning(args)
    if args.model_mode == 'train':
        ptuning.train()
    else:
        model_path = os.path.join(args.pt_saving_dir, args.cp_file)
        ptuning.load_model(model_path)
        # --test_name is a comma-separated list of split names; empty entries
        # (e.g. from a trailing comma) are ignored.
        test_list = [mode.strip() for mode in args.test_name.split(',') if mode.strip()]
        for mode in test_list:
            logger.info("\n{}: ".format(mode))
            ptuning.evaluate(mode=mode, print_metrics=True)
